# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

import vllm.envs as envs
from vllm.profiler.gpu_profiler import WorkerProfiler


class ConcreteWorkerProfiler(WorkerProfiler):
    """
    Minimal ``WorkerProfiler`` subclass used as a test double.

    It counts successful ``_start`` calls and all ``_stop`` calls, and can be
    told to raise from ``_start`` by setting ``should_fail_start``.
    """

    def __init__(self):
        # Bookkeeping attributes are assigned before the base initializer runs.
        self.should_fail_start = False
        self.stop_call_count = 0
        self.start_call_count = 0
        super().__init__()

    def _start(self) -> None:
        """Record one start, or raise when a failure is being simulated."""
        if not self.should_fail_start:
            self.start_call_count += 1
            return
        raise RuntimeError("Simulated start failure")

    def _stop(self) -> None:
        """Record one stop."""
        self.stop_call_count += 1


@pytest.fixture(autouse=True)
def reset_mocks(monkeypatch):
    """Zero the profiler env settings for each test and restore them afterwards.

    Both ``VLLM_PROFILER_DELAY_ITERS`` and ``VLLM_PROFILER_MAX_ITERS`` are set
    to 0 on the ``envs`` module before every test. They are set through
    ``monkeypatch`` so that pytest puts the previous values back at teardown,
    even when a test has reassigned them directly; otherwise the last value
    written here would leak into test modules that run later in the session.
    """
    # raising=False keeps the original "plain assignment never fails"
    # behaviour should the attribute not be resolvable on the module.
    monkeypatch.setattr(envs, "VLLM_PROFILER_DELAY_ITERS", 0, raising=False)
    monkeypatch.setattr(envs, "VLLM_PROFILER_MAX_ITERS", 0, raising=False)


def test_immediate_start_stop():
    """With no delay configured, start() and stop() take effect right away."""
    prof = ConcreteWorkerProfiler()

    # Starting: the hook ran once and both state flags are set.
    prof.start()
    assert prof.start_call_count == 1
    assert prof._active is True
    assert prof._running is True

    # Stopping: the hook ran once and both state flags are cleared.
    prof.stop()
    assert prof.stop_call_count == 1
    assert prof._active is False
    assert prof._running is False


def test_delayed_start():
    """With a delay of 2, the real start happens on the second step()."""
    envs.VLLM_PROFILER_DELAY_ITERS = 2
    prof = ConcreteWorkerProfiler()

    # The start request is accepted (active) but nothing is running yet.
    prof.start()
    assert prof.start_call_count == 0
    assert prof._running is False
    assert prof._active is True

    # First step: still waiting out the delay.
    prof.step()
    assert prof._running is False

    # Second step: delay satisfied, the start hook fires exactly once.
    prof.step()
    assert prof.start_call_count == 1
    assert prof._running is True


def test_max_iterations():
    """With max iters of 2, the third step() stops the profiler on its own."""
    envs.VLLM_PROFILER_MAX_ITERS = 2
    prof = ConcreteWorkerProfiler()

    prof.start()
    assert prof._running is True

    # Steps 1 and 2 stay within the limit, so profiling keeps running.
    for _ in range(2):
        prof.step()
        assert prof._running is True

    # Step 3 goes past the limit: the stop hook fires once.
    prof.step()
    assert prof.stop_call_count == 1
    assert prof._running is False


def test_delayed_start_and_max_iters():
    """Delay of 2 followed by a max of 2 profiled iterations."""
    envs.VLLM_PROFILER_MAX_ITERS = 2
    envs.VLLM_PROFILER_DELAY_ITERS = 2
    prof = ConcreteWorkerProfiler()

    prof.start()

    # Step 1: request is active, but the delay has not elapsed.
    prof.step()
    assert prof._active is True
    assert prof._running is False

    # Step 2: delay elapsed, profiling begins and the counter reads 1.
    prof.step()
    assert prof._profiling_for_iters == 1
    assert prof._active is True
    assert prof._running is True

    # Step 3: counter reads 2, still running.
    prof.step()
    assert prof._profiling_for_iters == 2
    assert prof._running is True

    # Step 4: goes past the max, so the profiler stops itself exactly once.
    prof.step()
    assert prof.stop_call_count == 1
    assert prof._running is False


def test_idempotency():
    """Repeated start() / stop() calls trigger each hook only once."""
    prof = ConcreteWorkerProfiler()

    # Two consecutive starts -> a single underlying start.
    for _ in range(2):
        prof.start()
    assert prof.start_call_count == 1

    # Two consecutive stops -> a single underlying stop.
    for _ in range(2):
        prof.stop()
    assert prof.stop_call_count == 1


def test_step_inactive():
    """step() on a profiler that was never started must not start it."""
    envs.VLLM_PROFILER_DELAY_ITERS = 2
    prof = ConcreteWorkerProfiler()

    # Step as many times as the configured delay without calling start().
    for _ in range(2):
        prof.step()

    # No start was requested, so the start hook must not have run.
    assert prof.start_call_count == 0


def test_start_failure():
    """A failing ``_start`` hook leaves the profiler active but not running."""
    prof = ConcreteWorkerProfiler()
    prof.should_fail_start = True

    # start() is expected not to propagate the simulated error.
    prof.start()

    assert prof.start_call_count == 0  # the hook raised before counting
    assert prof._active is True  # the request itself stays registered
    assert prof._running is False  # but profiling never began


def test_shutdown():
    """shutdown() invokes the stop hook only when profiling is running."""
    prof = ConcreteWorkerProfiler()

    # Idle profiler: shutdown has nothing to stop.
    prof.shutdown()
    assert prof.stop_call_count == 0

    # Running profiler: shutdown stops it once.
    prof.start()
    prof.shutdown()
    assert prof.stop_call_count == 1


def test_mixed_delay_and_stop():
    """A manual stop() during the delay window cancels the pending start."""
    envs.VLLM_PROFILER_DELAY_ITERS = 5
    prof = ConcreteWorkerProfiler()

    # Request a start and take two of the five delay steps.
    prof.start()
    for _ in range(2):
        prof.step()

    # Cancel before the delay has run out.
    prof.stop()
    assert prof._active is False

    # Enough further steps to pass the original delay; none may start it.
    for _ in range(3):
        prof.step()

    assert prof.start_call_count == 0
